import math
import matplotlib.pyplot as plt
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_trans

# Set the dataset pipeline's global random seed (used by shuffling and any
# random transforms) so repeated runs show the same samples.
ds.set_seed(1)

DATA_DIR_CIFAR10 = './datasets/cifar10/cifar-10-batches-bin/'

# Source pipeline: four shuffled CIFAR-10 samples read from DATA_DIR_CIFAR10.
dataset1 = ds.Cifar10Dataset(DATA_DIR_CIFAR10, shuffle=True, num_samples=4)

# Transformed pipeline: dataset1 with Invert applied to its 'image' column.
# Other transforms previously tried here (swap in to compare):
#   c_trans.RandomCrop([10, 10])
#   c_trans.RandomHorizontalFlip(prob=0.8)
#   c_trans.Resize(size=[101, 101])
invert = c_trans.Invert()
dataset2 = dataset1.map(input_columns=['image'], operations=invert)


def printDataset(dataset_list, name_list):
    """Plot every sample of several datasets in a grid, one dataset per row.

    The figure has ``len(dataset_list)`` rows and as many columns as the
    largest dataset reports via ``get_dataset_size()``. Each subplot shows one
    sample's image, titled with its label. The image shape and label of every
    sample are also printed to stdout.

    Args:
        dataset_list: datasets to display. Each must provide
            ``get_dataset_size()`` and ``create_dict_iterator(output_numpy=True)``
            yielding dicts with ``'image'`` and ``'label'`` entries.
        name_list: display name for each dataset, indexed in parallel with
            ``dataset_list``; used as the prefix of the printed lines.

    If ``dataset_list`` is empty there is nothing to draw, and the function
    returns without creating a figure.
    """
    if not dataset_list:
        # max() below would raise on an empty sequence; nothing to plot anyway.
        return

    rows = len(dataset_list)
    # Size the grid for the largest dataset so every row fits.
    columns = max(dataset.get_dataset_size() for dataset in dataset_list)

    for row_idx, dataset in enumerate(dataset_list):
        # Each dataset starts at the first cell of its own row
        # (matplotlib subplot positions are 1-based).
        row_start = row_idx * columns + 1
        samples = dataset.create_dict_iterator(output_numpy=True)
        for col_idx, data in enumerate(samples):
            image, label = data['image'], data['label']
            plt.subplot(rows, columns, row_start + col_idx)
            plt.imshow(image)
            plt.title(label)
            print(name_list[row_idx], " shape:", image.shape, 'label:', label)

    plt.show()


# Row 1: dataset1 as loaded; row 2: dataset2, i.e. dataset1 with Invert applied
# to the 'image' column, so it is labelled as inverted (not cropped).
printDataset([dataset1, dataset2], ['source image', 'inverted image'])

# for data in dataset_cifar10.create_dict_iterator():
#     print('Image shape:', data['image'].shape, ',label:', data['label'])
#
# dataset = dataset_cifar10.batch(batch_size=4, drop_remainder=False)

# print('after batch:')
#
# for data in dataset.create_dict_iterator():
#     print('Image shape:', data['image'].shape, ',label:', data['label'])
#
#
# def plt_result(dataset, row):
#     num = 1
#     for data in dataset.create_dict_iterator():
#         for i in range(4):
#             plt.subplot(row, math.ceil(8 / row), num)
#             image = data['image'].asnumpy()
#             label = data['label'].asnumpy()
#
#             plt.title(label[i])
#             plt.imshow(image[i], interpolation="None")
#             num += 1
#
#         plt.show()

# plt_result(dataset,2)

# data_cifar10 = ds.Cifar10Dataset(DATA_DIR_CIFAR10, num_samples=4)
# for data in dataset_cifar10.create_dict_iterator():
#     print('Image shape:', data['image'].shape, ',label:', data['label'])
#
# print('after repat.....')
# d = data_cifar10.repeat(count=2)
# for data in d.create_dict_iterator():
#     print('Image shape:', data['image'].shape, ',label:', data['label'])
